import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from scipy.io import loadmat
from tqdm import tqdm
import json

# Seed PyTorch and NumPy so that repeated runs draw the same random numbers.
torch.manual_seed(0)
np.random.seed(0)

# Recurrent architecture used for this run.
# Alternative settings kept from earlier runs:
#   'identity_rnn_relu', 'gaussian_rnn_relu', 'lstm', 'gru'
rnn = 'rnn_tanh'

# Model sizes explored by the grid search.
n_input = 1                        # input_size handed to the recurrent layer
n_layers = [1, 2]                  # candidate numbers of stacked layers
n_hiddens = [50, 100, 200, 500]    # candidate hidden-state widths

# Optimisation settings.
lrs = [1e-2, 1e-3, 1e-4, 1e-5]     # candidate learning rates
bs = 50                            # mini-batch size
max_ep = 200                       # upper bound on training epochs
threshold = 1e-6                   # stop once the epoch loss changes by less than this
n_fold = 5                         # number of cross-validation folds

# Per-dataset outputs, keyed by dataset directory name.
results, best_lr, best_n_layer, best_n_hidden = {}, {}, {}, {}

# Discover the datasets: every sub-directory of `data_root_dir` is one dataset.
data_root_dir = 'datasets'
# Build a filtered, sorted list instead of deleting from the list while
# indexing into it.  All macOS '.DS_Store' entries are dropped (not only the
# first one), and plain files are skipped because each entry is later used as
# a directory that contains the .mat files.
data_dirs = sorted(
    entry for entry in os.listdir(data_root_dir)
    if 'DS_Store' not in entry
    and os.path.isdir(os.path.join(data_root_dir, entry))
)

def _build_model(rnn_type, n_hidden, n_layer, n_classes, device):
    """Create the recurrent network and its linear read-out matrix.

    Args:
        rnn_type: one of 'identity_rnn_relu', 'gaussian_rnn_relu', 'rnn_tanh',
            'gru' or 'lstm' (matched case-insensitively).
        n_hidden: width of the hidden state.
        n_layer: number of stacked recurrent layers.
        n_classes: number of output classes (columns of the read-out).
        device: torch.device on which both objects are placed.

    Returns:
        (model, V) where V is a leaf nn.Parameter of shape
        (n_hidden, n_classes) living on `device`.

    Raises:
        ValueError: if `rnn_type` is not one of the supported names.
    """
    kind = rnn_type.lower()
    if kind == 'identity_rnn_relu':
        model = nn.RNN(input_size=n_input, hidden_size=n_hidden,
                       num_layers=n_layer, nonlinearity='relu')
        # identity recurrence and zero biases on the first layer
        model.weight_hh_l0.data = torch.eye(n_hidden)
        model.bias_hh_l0.data = torch.zeros_like(model.bias_hh_l0)
        model.bias_ih_l0.data = torch.zeros_like(model.bias_ih_l0)
    elif kind == 'gaussian_rnn_relu':
        model = nn.RNN(input_size=n_input, hidden_size=n_hidden,
                       num_layers=n_layer, nonlinearity='relu')
        # Gaussian weights (std sqrt(2 / n_hidden) for the recurrence) and
        # zero biases on the first layer
        model.weight_hh_l0.data = torch.randn_like(model.weight_hh_l0) / np.sqrt(n_hidden) * np.sqrt(2)
        model.weight_ih_l0.data = torch.randn_like(model.weight_ih_l0)
        model.bias_hh_l0.data = torch.zeros_like(model.bias_hh_l0)
        model.bias_ih_l0.data = torch.zeros_like(model.bias_ih_l0)
    elif kind == 'rnn_tanh':
        model = nn.RNN(input_size=n_input, hidden_size=n_hidden, num_layers=n_layer)
    elif kind == 'gru':
        model = nn.GRU(input_size=n_input, hidden_size=n_hidden, num_layers=n_layer)
    elif kind == 'lstm':
        model = nn.LSTM(input_size=n_input, hidden_size=n_hidden, num_layers=n_layer)
    else:
        # previously an unknown name left `model` undefined (or stale)
        raise ValueError('unknown rnn type: {!r}'.format(rnn_type))
    model.to(device)

    # Sample on the CPU, move, then wrap: this keeps V a *leaf* parameter on
    # the target device so that the optimizer can actually update it.
    V = nn.Parameter((torch.randn(n_hidden, n_classes) / np.sqrt(n_hidden) * np.sqrt(2)).to(device))
    return model, V


def _log_probs(model, V, X, idx, device):
    """Return log class probabilities for the rows `idx` of `X`.

    `X` is indexed as (sample, time step); the selected rows are transposed
    to the sequence-first layout (time, batch, 1) that the recurrent modules
    are fed with.  The last layer's final hidden state is projected by `V`.
    """
    batch = torch.from_numpy(np.ascontiguousarray(X[idx], dtype=np.float32))
    batch = batch.t().unsqueeze(2).to(device)
    _, state = model(batch)
    if isinstance(state, tuple):
        # nn.LSTM returns (h_n, c_n); the read-out uses the hidden state h_n
        state = state[0]
    return torch.log_softmax(torch.mm(state[-1], V), dim=1)


def _train(model, V, X, y, lr, device):
    """Train `model` and `V` with RMSprop until the epoch loss converges.

    Stops after `max_ep` epochs, when the mean epoch loss changes by less than
    `threshold`, or when the loss becomes NaN (further updates cannot recover).

    Args:
        X: array indexed as (sample, time step).
        y: int64 array of class indices, one per sample.
        lr: learning rate.

    Returns:
        (preds, labels): predictions collected during the last epoch and the
        labels in the same (shuffled) order.
    """
    # V is optimised together with the recurrent weights; it used to be left
    # out of the optimizer and therefore stayed at its random initial value.
    optimizer = optim.RMSprop(list(model.parameters()) + [V], lr=lr)
    criterion = nn.NLLLoss(reduction='sum')

    n_samples = X.shape[0]
    order = np.arange(n_samples)
    preds = np.array([], dtype=np.int64)
    prev_loss = None
    for _ in range(max_ep):
        model.train()
        order = np.random.permutation(n_samples)  # fresh shuffle every epoch
        epoch_loss = 0.0
        epoch_preds = []
        for start in range(0, n_samples, bs):
            idx = order[start:start + bs]
            optimizer.zero_grad()
            out = _log_probs(model, V, X, idx, device)
            loss_sum = criterion(out, torch.from_numpy(y[idx]).to(device))
            # Average over the batch.  The old code divided by the first
            # dimension of the transposed input, i.e. the sequence length.
            (loss_sum / len(idx)).backward()
            optimizer.step()
            # .item() drops the autograd history instead of keeping every
            # batch's loss tensor alive for the whole epoch
            epoch_loss += loss_sum.item()
            epoch_preds.append(out.argmax(dim=1).cpu().numpy())
        if epoch_preds:
            preds = np.concatenate(epoch_preds)

        mean_loss = epoch_loss / max(n_samples, 1)
        if np.isnan(mean_loss):
            break
        # stop criterion: convergence of loss
        if prev_loss is not None and abs(mean_loss - prev_loss) < threshold:
            break
        prev_loss = mean_loss
    return preds, y[order]


def _evaluate(model, V, X, y, device):
    """Return (mean NLL, predicted class indices) on (X, y), without gradients."""
    model.eval()
    criterion = nn.NLLLoss(reduction='sum')
    n_samples = X.shape[0]
    total_loss = 0.0
    batch_preds = []
    with torch.no_grad():
        for start in range(0, n_samples, bs):
            idx = slice(start, start + bs)
            out = _log_probs(model, V, X, idx, device)
            total_loss += criterion(out, torch.from_numpy(y[idx]).to(device)).item()
            batch_preds.append(out.argmax(dim=1).cpu().numpy())
    preds = np.concatenate(batch_preds) if batch_preds else np.array([], dtype=np.int64)
    mean_loss = total_loss / n_samples if n_samples else float('nan')
    return mean_loss, preds


# Use the GPU when there is one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Make sure the output directory exists before hours of training are spent.
os.makedirs('results', exist_ok=True)

# For every dataset: pick (lr, n_layer, n_hidden) by k-fold cross-validation,
# retrain on the full training set, evaluate on the test set, log everything.
for data_dir in tqdm(data_dirs):

    print('-----------working on dataset: {}-----------'.format(data_dir))

    # load data
    data_path = os.path.join(data_root_dir, data_dir)
    X_train = loadmat(os.path.join(data_path, 'X_train_raw.mat'))['data']
    X_test = loadmat(os.path.join(data_path, 'X_test_raw.mat'))['data']
    y_train = np.ravel(loadmat(os.path.join(data_path, 'y_train_raw.mat'))['data'])
    y_test = np.ravel(loadmat(os.path.join(data_path, 'y_test_raw.mat'))['data'])

    # remap the labels of train and test set to 0 .. n_classes - 1
    # (np.unique returns the sorted distinct labels, searchsorted their rank)
    unique_y = np.unique(np.concatenate((y_train, y_test)))
    y_train = np.searchsorted(unique_y, y_train).astype(np.int64)
    y_test = np.searchsorted(unique_y, y_test).astype(np.int64)
    n_classes = len(unique_y)

    # get cv folds: contiguous, balanced, driven by n_fold; folds that would
    # leave the validation or the training part empty are dropped
    indices_total = np.arange(X_train.shape[0])
    folds = [fold for fold in np.array_split(indices_total, n_fold)
             if 0 < len(fold) < len(indices_total)]

    # start experiment; the winners are reset for every dataset so that a
    # failed search can never silently reuse the previous dataset's values
    best_cv_loss = float('inf')
    best_cv_lr = best_cv_n_hidden = best_cv_n_layer = None
    for lr in lrs:
        for n_layer in n_layers:
            for n_hidden in n_hiddens:

                # cross validation
                fold_losses = []
                for valid_idx in folds:
                    train_idx = np.setdiff1d(indices_total, valid_idx)
                    model, V = _build_model(rnn, n_hidden, n_layer, n_classes, device)
                    _train(model, V, X_train[train_idx], y_train[train_idx], lr, device)
                    fold_loss, _ = _evaluate(model, V, X_train[valid_idx], y_train[valid_idx], device)
                    fold_losses.append(fold_loss)

                    # release GPU memory before the next model is built
                    del model, V
                    if device.type == 'cuda':
                        torch.cuda.empty_cache()

                if not fold_losses:
                    continue

                # record best cross validation configurations
                # (a NaN average never compares smaller, so it is never selected)
                avg_cv_loss = sum(fold_losses) / len(fold_losses)
                if avg_cv_loss < best_cv_loss:
                    best_cv_loss = avg_cv_loss
                    best_cv_lr = lr
                    best_cv_n_hidden = n_hidden
                    best_cv_n_layer = n_layer

    if best_cv_lr is None:
        print('WARNING: no configuration gave a usable cross-validation loss '
              'on {}; skipping this dataset'.format(data_dir))
        continue

    # training the model with the best cross validated hyper-parameters
    model, V = _build_model(rnn, best_cv_n_hidden, best_cv_n_layer, n_classes, device)
    pred, y_seen = _train(model, V, X_train, y_train, best_cv_lr, device)

    # final training acc (predictions of the last epoch, labels in the same order)
    acc_train = float(np.mean(pred == y_seen))

    # evaluation
    _, pred = _evaluate(model, V, X_test, y_test, device)
    acc = float(np.mean(pred == y_test))

    del model, V
    if device.type == 'cuda':
        torch.cuda.empty_cache()

    # log results
    results[data_dir] = acc
    best_lr[data_dir] = best_cv_lr
    best_n_hidden[data_dir] = best_cv_n_hidden
    best_n_layer[data_dir] = best_cv_n_layer
    log_str = 'dataset = {}\n'.format(data_dir) + \
                '      random acc = {}\n'.format(1./len(unique_y)) + \
                '      train acc = {}\n'.format(acc_train) + \
                '      test acc = {}\n'.format(acc) + \
                '      best lr = {}\n'.format(best_cv_lr) + \
                '      best n layer = {}\n'.format(best_cv_n_layer) + \
                '      best n hidden = {}\n'.format(best_cv_n_hidden) + \
                '      size of training data = {}, length={}\n'.format(X_train.shape[0], X_train.shape[1]) + \
                '      size of test data = {}, length={}\n'.format(X_test.shape[0], X_test.shape[1]) + \
                '      num classes = {}\n\n\n'.format(len(unique_y))
    print(log_str)

    with open(os.path.join('results', 'logs-{}.txt'.format(rnn)), 'a+') as output_file:
        output_file.write(log_str)

    # rewrite the JSON summaries after every dataset so that partial results
    # survive an interrupted run
    for stem, payload in (('scores', results),
                          ('best_lr', best_lr),
                          ('best_n_layer', best_n_layer),
                          ('best_n_hidden', best_n_hidden)):
        with open(os.path.join('results', '{}-{}.json'.format(stem, rnn)), 'w') as output_file:
            output_file.write(json.dumps(payload))

    
        
    
    